import pandas as pd
import math
from datetime import datetime
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from bots.botlibs.oscillators_labeling import *
from bots.botlibs.tester_lib import test_model
import random
import itertools

def get_prices() -> pd.DataFrame:
    """Load the close-price series for the configured symbol.

    Reads ``files/<symbol>.csv`` (``<symbol>`` is ``hyper_params['symbol']``)
    as whitespace-separated text. The file must provide the columns
    ``<DATE>``, ``<TIME>`` and ``<CLOSE>``.

    Returns:
        A DataFrame indexed by the parsed timestamp (index name ``time``)
        with a single ``close`` column; rows containing NaN are dropped.
    """
    # Raw string on purpose: '\s' in a plain literal is an invalid escape
    # sequence (SyntaxWarning on Python 3.12+, planned to become an error).
    raw = pd.read_csv('files/' + hyper_params['symbol'] + '.csv', sep=r'\s+')

    prices = pd.DataFrame(columns=['time', 'close'])
    # Date and time live in separate text columns; join them before parsing.
    prices['time'] = raw['<DATE>'] + ' ' + raw['<TIME>']
    prices['time'] = pd.to_datetime(prices['time'], format='mixed')
    prices['close'] = raw['<CLOSE>']
    prices.set_index('time', inplace=True)
    # NOTE(review): the index was already parsed to datetimes above, so this
    # second conversion looks redundant - confirm before removing it.
    prices.index = pd.to_datetime(prices.index, unit='s')
    return prices.dropna()

def get_features(data: pd.DataFrame) -> pd.DataFrame:
    """Append one "distance from rolling mean" column per configured period.

    For the k-th entry of ``hyper_params['periods']`` a column named
    ``str(k)`` is added, holding ``data - data.rolling(period).mean()``.

    Args:
        data: Input frame. The per-period difference is itself a frame, so
            this presumably expects a single-column input (e.g. ``close``).

    Returns:
        A copy of ``data`` with the extra columns, without the rows that
        contain NaN (the warm-up rows of the rolling windows).
    """
    features = data.copy()
    # Untouched snapshot of the input: the differences must be computed from
    # the original columns only, not from the feature columns added below.
    source = data.copy()

    for column_id, window in enumerate(hyper_params['periods']):
        rolling_mean = source.rolling(window).mean()
        features[str(column_id)] = source - rolling_mean

    return features.dropna()

def get_labels(dataset, min = 1, max = 15) -> pd.DataFrame:
    """Label each row by comparing its close with a randomly chosen future close.

    For every row except the last ``max`` ones, a look-ahead distance is
    drawn uniformly from ``[min, max]`` and the close that many rows ahead
    is compared with the current close, using ``hyper_params['markup']``
    as the tolerance:

    * ``1.0`` - future close plus markup is below the current close
    * ``0.0`` - future close minus markup is above the current close
    * ``2.0`` - neither of the above

    Args:
        dataset: Frame with a ``close`` column.
        min: Smallest look-ahead distance, in rows.
        max: Largest look-ahead distance, in rows; also the number of
            trailing rows that receive no label and are cut off.

    Returns:
        A copy of the labelled leading rows with a new ``labels`` column,
        rows containing NaN removed.
    """
    labelled_rows = dataset.shape[0] - max
    labels = []

    for pos in range(labelled_rows):
        horizon = random.randint(min, max)
        price_now = dataset['close'].iloc[pos]
        price_later = dataset['close'].iloc[pos + horizon]
        tolerance = hyper_params['markup']

        if (price_later + tolerance) < price_now:
            outcome = 1.0
        elif (price_later - tolerance) > price_now:
            outcome = 0.0
        else:
            outcome = 2.0
        labels.append(outcome)

    # Only the rows that actually received a label are kept.
    labelled = dataset.iloc[:len(labels)].copy()
    labelled['labels'] = labels
    return labelled.dropna()

# def get_features(data: pd.DataFrame) -> pd.DataFrame:
#     pFixed = data.copy()
#     count = 0

#     for period in hyper_params['periods']:
#         delta = data.diff()
#         gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
#         loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
#         rs = gain / loss
#         pFixed[str(count)] = 100 - (100 / (1 + rs))
#         count += 1
#     return pFixed.dropna()

# def get_features(data: pd.DataFrame) -> pd.DataFrame:
#     pFixed = data.copy()
#     count = 0

#     for period in hyper_params['periods']:
#         delta = data.diff()
#         gain = delta.where(delta > 0, 0)
#         loss = -delta.where(delta < 0, 0)
        
#         avg_gain = gain.ewm(alpha=1/period, adjust=False).mean()
#         avg_loss = loss.ewm(alpha=1/period, adjust=False).mean()
        
#         rs = avg_gain / avg_loss
#         pFixed[str(count)] = 100 - (100 / (1 + rs))
#         count += 1

#     return pFixed.dropna()

def fit_final_models(dataset) -> list:
    """Train a main and a meta CatBoost classifier and score the pair.

    Columns are addressed by position: the first column and the last two
    are excluded from the features, the second-to-last column is the main
    target and the last column is the meta target. The main model is trained
    only on rows whose ``meta_labels`` equals 1; the meta model uses all rows.

    Args:
        dataset: Frame with a ``meta_labels`` column, laid out as described
            above.

    Returns:
        ``[score, model, meta_model]`` where ``score`` is the value returned
        by ``test_model`` on freshly loaded features (replaced by ``-1.0``
        when it is NaN).
    """
    # Main model: features/target restricted to the rows the meta labels keep.
    kept_rows = dataset[dataset['meta_labels'] == 1]
    X = kept_rows[kept_rows.columns[1:-2]]
    y = kept_rows[kept_rows.columns[-2]].astype('int16')

    # Meta model: same feature columns, every row, last column as target.
    X_meta = dataset[dataset.columns[1:-2]]
    y_meta = dataset[dataset.columns[-1]].astype('int16')

    # Shuffled 80/20 train/validation splits (main first, then meta).
    train_X, test_X, train_y, test_y = train_test_split(
        X, y, train_size=0.8, test_size=0.2, shuffle=True)
    train_X_m, test_X_m, train_y_m, test_y_m = train_test_split(
        X_meta, y_meta, train_size=0.8, test_size=0.2, shuffle=True)

    def new_classifier():
        # Both models share one configuration; build a fresh instance each time.
        return CatBoostClassifier(iterations=1000,
                                  learning_rate=0.1,
                                  custom_loss=['F1'],
                                  eval_metric='F1',
                                  verbose=False,
                                  use_best_model=True,
                                  task_type='CPU')

    model = new_classifier()
    model.fit(train_X, train_y, eval_set=(test_X, test_y),
              early_stopping_rounds=25, plot=False)

    meta_model = new_classifier()
    meta_model.fit(train_X_m, train_y_m, eval_set=(test_X_m, test_y_m),
                   early_stopping_rounds=25, plot=False)

    # Score the pair on the full feature set rebuilt from the price file.
    data = get_features(get_prices())
    score = test_model(data,
                       [model, meta_model],
                       hyper_params['stop_loss'],
                       hyper_params['take_profit'],
                       hyper_params['forward'],
                       hyper_params['backward'],
                       hyper_params['markup'],
                       plt=False)

    # A NaN score cannot be ordered against others; substitute a fixed value.
    if math.isnan(score):
        score = -1.0
        print('R2 is fixed to -1.0')
    print('R2: ' + str(score))

    return [score, model, meta_model]

def export_model_to_ONNX(model, model_number):
    """Export a trained model pair to ONNX and write the matching MQL include.

    Files are written under ``hyper_params['export_path']``:

    * ``catmodel<N>.onnx`` - saved from ``model[1]``
    * ``catmodel_m<N>.onnx`` - saved from ``model[2]``
    * ``<symbol> ONNX include<N>.mqh`` - include file that references both
      ONNX resources, declares the ``Periods`` array from
      ``hyper_params['periods']`` and defines the ``fill_arays`` function.

    Args:
        model: Sequence whose items at index 1 and 2 expose ``save_model``
            (the ``[score, model, meta_model]`` list built by
            ``fit_final_models``).
        model_number: Suffix ``<N>`` used in all three file names.
    """
    export_dir = hyper_params['export_path']
    suffix = str(model_number)

    # Both classifiers are exported with the same ONNX metadata.
    for position, stem in ((1, 'catmodel'), (2, 'catmodel_m')):
        model[position].save_model(
            export_dir + stem + suffix + '.onnx',
            format="onnx",
            export_parameters={
                'onnx_domain': 'ai.catboost',
                'onnx_model_version': 1,
                'onnx_doc_string': 'test model for BinaryClassification',
                'onnx_graph_name': 'CatBoostModel_for_BinaryClassification'
            },
            pool=None)

    periods = hyper_params['periods']
    include_lines = [
        # Raw string: the backslashes are literal path separators in the MQL
        # include and must not be read as Python escape sequences.
        r'#include <Math\Stat\Math.mqh>',
        '#resource "catmodel' + suffix + '.onnx" as uchar ExtModel[]',
        '#resource "catmodel_m' + suffix + '.onnx" as uchar ExtModel2[]',
        'int Periods[' + str(len(periods)) + '] = {'
        + ','.join(map(str, periods)) + '};',
        '',
        # Feature builder: last close minus the mean close over each period.
        'void fill_arays( double &features[]) {',
        '   double pr[], ret[];',
        '   ArrayResize(ret, 1);',
        '   for(int i=ArraySize(Periods)-1; i>=0; i--) {',
        '       CopyClose(NULL,PERIOD_CURRENT,1,Periods[i],pr);',
        '       ret[0] = pr[ArraySize(pr)-1] - MathMean(pr);',
        '       ArrayInsert(features, ret, ArraySize(features), 0, WHOLE_ARRAY); }',
        '   ArraySetAsSeries(features, true);',
        '}',
        '',
        '',
    ]
    code = '\n'.join(include_lines)

    include_path = (export_dir + str(hyper_params['symbol'])
                    + ' ONNX include' + suffix + '.mqh')
    # Context manager so the handle is closed even if the write fails.
    with open(include_path, "w") as file:
        file.write(code)

    print('The file ' + 'ONNX include' + '.mqh ' + 'has been written to disk')

# Shared configuration read by the functions and the training script in this file.
hyper_params = dict(
    # Instrument/timeframe tag; also the stem of the price CSV under files/.
    symbol='AUDCAD_H1',
    # Directory that receives the exported ONNX models and the .mqh include.
    export_path='/Users/dmitrievsky/Library/Application Support/net.metaquotes.wine.metatrader5/drive_c/Program Files/MetaTrader 5/MQL5/Include/Mean reversion/',
    # Numeric id of the model.
    model_number=0,
    # Price tolerance used when labelling; also forwarded to test_model.
    markup=0.0001,
    # Forwarded to test_model.
    stop_loss=0.01,
    take_profit=0.01,
    # Rolling-window lengths for the features: 5, 10, ..., 45.
    periods=list(range(5, 50, 5)),
    # Training rows are those strictly between 'backward' and 'forward'.
    backward=datetime(2010, 1, 1),
    forward=datetime(2024, 1, 1),
)


# --- Training run -----------------------------------------------------------
# Train several model pairs, keep the best-scoring one, re-test it with
# plotting enabled and export it.

# The feature frame depends only on the price file and hyper_params, so it is
# built once instead of being re-read from disk on every training round.
base_features = get_features(get_prices())

options = []
for i in range(5):
    print('Learn ' + str(i) + ' model')
    # Each round labels its own copy so the shared frame is never modified.
    dataset = base_features.copy()
    # Alternative labelers (swap in for get_labels_rsi below):
    # dataset = get_labels_cci(dataset, cci_period=50, 
    #                          oversold_level=-130, overbought_level=130)
    # dataset = get_labels_stochastic(dataset, stoch_period=30, smooth_k=15,
    #                                 oversold_level=10, overbought_level=90)
    # dataset = get_labels_bb(dataset, bb_period=25, num_std=2)
    dataset = get_labels_rsi(dataset, rsi_period=9,
                             oversold_level=30, overbought_level=70)
    # dataset = get_labels_fourier(dataset, lookback_period=100, high_pass_cutoff_idx=5,
    #                    std_multiplier=1.5)
    # dataset = get_labels_profit_rsi_profit_check(dataset, rsi_period=3,
    #                         oversold_level=30.0, overbought_level=70.0,
    #                         min_forecast_period=1, max_forecast_period=15,
    #                         markup=hyper_params['markup'])
    # dataset = get_labels(dataset)

    # Rows labelled 2.0 are excluded from main-model training (meta label 0).
    dataset['meta_labels'] = (dataset['labels'] != 2.0).astype(float)
    # Train only on rows strictly inside the (backward, forward) window.
    data = dataset[(dataset.index < hyper_params['forward']) & (dataset.index > hyper_params['backward'])].copy()
    options.append(fit_final_models(data))

# Ascending by score, so the best model pair ends up last.
options.sort(key=lambda x: x[0])
data = base_features.copy()
test_model(data, 
        options[-1][1:], 
        hyper_params['stop_loss'], 
        hyper_params['take_profit'],
        hyper_params['forward'],
        hyper_params['backward'],
        hyper_params['markup'],
        plt=True)

# Previously a bare expression whose value was discarded when run as a script.
print(options[-1][2].get_best_score()['validation'])

export_model_to_ONNX(options[-1], hyper_params['model_number'])
